package cn.itcast.up.mldemo

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{MinMaxScaler, StandardScaler}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Author itcast
 * Date 2020/3/22 15:26
 * Desc Introductory KMeans demo: clusters the Iris data set with the KMeans
 * implementation shipped with Spark ML.
 *
 * Note: KMeans is an unsupervised algorithm. The Iris data set has 4 feature columns
 * (used for clustering) and 1 label column (not used while clustering; it is only
 * consulted afterwards to inspect how well the clusters line up with the labels).
 */
object G_Demo_KMeans {

  /** Input file used when no path is given on the command line. */
  private val DefaultDataPath: String =
    "file:///D:\\授课\\191021-35\\用户画像\\day07\\data\\ml\\iris_KMeans.libsvm"

  /**
   * Loads the libsvm file, min-max scales the features, fits a 3-cluster KMeans model
   * and prints the cluster assignments together with a label/cluster cross-count.
   *
   * @param args optional; when present, `args(0)` is the path of the libsvm file to
   *             load, otherwise [[DefaultDataPath]] is used
   */
  def main(args: Array[String]): Unit = {
    val dataPath: String = args.headOption.getOrElse(DefaultDataPath)

    // 0. Prepare the environment and the data.
    val spark: SparkSession = SparkSession.builder()
      .appName("ml")
      .master("local[*]")
      .getOrCreate()

    // Everything after session creation runs inside try/finally so that the session
    // (and its local executors) is released even when a stage fails.
    try {
      spark.sparkContext.setLogLevel("WARN")

      // Spark can read libsvm files directly and parses them into (label, features).
      val irisDF: DataFrame = spark.read.format("libsvm").load(dataPath)
      irisDF.show(false)
      irisDF.printSchema()
      /*
      +-----+-------------------------------+
      |label|features                       |
      +-----+-------------------------------+
      |1.0  |(4,[0,1,2,3],[5.1,3.5,1.4,0.2])|
      |1.0  |(4,[0,1,2,3],[4.9,3.0,1.4,0.2])|
      |1.0  |(4,[0,1,2,3],[4.7,3.2,1.3,0.2])|

      root
       |-- label: double (nullable = true)
       |-- features: vector (nullable = true)
       */

      // The printed data shows the features are already assembled into a vector,
      // but they still need to be normalized/standardized before clustering.

      // 1. Feature engineering - normalization.
      // MinMaxScaler rescales every feature to [0,1]:
      //   scaled = (value - columnMin) / (columnMax - columnMin)
      // StandardScaler is the alternative: it scales each feature to unit standard
      // deviation (and optionally centres it to zero mean), which suits data that is
      // roughly normally distributed.
      val scalerDF: DataFrame = new MinMaxScaler()
        .setInputCol("features")
        .setOutputCol("scaler_features")
        .fit(irisDF)
        .transform(irisDF)
      scalerDF.show(false)
      /*
      Sample of the added column (first row):
      scaler_features = [0.22222222222222213,0.6249999999999999,0.06779661016949151,0.04166666666666667]
       */

      // 2. Train the KMeans model.
      val model: KMeansModel = new KMeans()
        .setK(3)                                 // number of clusters
        .setSeed(100)                            // fixed seed so every run is reproducible
        .setMaxIter(20)                          // upper bound on iterations
        .setFeaturesCol("scaler_features")       // cluster on the scaled features
        .setPredictionCol("predictClusterIndex") // output column: index of the assigned cluster centre
        .fit(scalerDF)

      // 3. Assign every row to a cluster.
      // KMeans is unsupervised, so the full data set is used for both fitting and prediction.
      val result: DataFrame = model.transform(scalerDF)

      // 4. Inspect the clustering quality.
      result.show(false)
      result.groupBy("label", "predictClusterIndex").count().show(false)
      /*
      result (first rows; the remaining rows shown by default have the same shape):
      |label|features                       |scaler_features                           |predictClusterIndex|
      |1.0  |(4,[0,1,2,3],[5.1,3.5,1.4,0.2])|[0.2222...,0.6249...,0.0677...,0.0416...] |0                  |
      |1.0  |(4,[0,1,2,3],[4.9,3.0,1.4,0.2])|[0.1666...,0.4166...,0.0677...,0.0416...] |0                  |
      |1.0  |(4,[0,1,2,3],[4.7,3.2,1.3,0.2])|[0.1111...,0.5,0.0508...,0.0416...]       |0                  |

      Label vs. predicted cluster:
      +-----+-------------------+-----+
      |label|predictClusterIndex|count|
      +-----+-------------------+-----+
      |2.0  |1                  |40   |
      |1.0  |0                  |50   |
      |2.0  |2                  |10   |
      |3.0  |1                  |8    |
      |3.0  |2                  |42   |
      +-----+-------------------+-----+
      Label 1.0 maps cleanly onto cluster 0; labels 2.0 and 3.0 overlap partially
      between clusters 1 and 2.
       */
    } finally {
      spark.stop()
    }
  }
}
